#!/usr/bin/env python
# coding: utf-8

#Import libraries
import numpy as np
import random
import pandas as pd
from multiprocessing import Pool
import json
import sys
import os
import gc
import time
import multiprocessing
import logging

# Capture the command-line arguments passed in by the shell script
# (everything after the script name).
args = sys.argv[1:]

# Positional arguments, in the order the shell script supplies them.
sample_name = args[0]
outdir = args[1]
fasta_filename = args[2]
contig_file = args[3]
desired_contig = args[4]
# Command-line values are always strings, and every non-empty string
# (including "False" or "0") is truthy. Convert the flag to a real bool so
# that the usual "off" spellings actually disable the filter; any other
# non-empty value keeps enabling it, as before.
discard_deletions = args[5].strip().lower() not in {
    "", "0", "false", "f", "no", "n", "none", "off",
}
script_paths = args[6]

# Make the directory that holds the worker module importable.
sys.path.append(script_paths)

# Import the worker functions; this must stay after the sys.path update above.
from sampler_worker_script import worker_function, sampler_function

# Input and output file paths, all located in <outdir>/<sample_name>/.
path_to_text_file = outdir + "/" + sample_name + "/text.txt"
MonteCarloErrorLog = outdir + "/" + sample_name + "/MonteCarloError.log"
MonteCarloDataFrame = outdir + "/" + sample_name + "/MonteCarlo.txt"
aggregated_results = outdir + "/" + sample_name + "/MonteCarlo_Results.json"


# Sampling parameters.
Num_Reps = 10              # total number of replicates scheduled by main()
Sample_Size = 1000000      # forwarded to sampler_function for each replicate
max_concurrent_tasks = 9   # upper bound on simultaneously running workers

# Define Functions

def analyze_fasta(fasta_filename, contig_ranges=None):
    """
    Find the positions of consecutive identical bases (homopolymer runs) in a FASTA file.

    Sequences are converted to uppercase before comparison, so 'a' and 'A' belong
    to the same run. Positions are 1-based. The full header text after '>' is used
    as the sequence name. Blank lines are ignored.

    Parameters:
    fasta_filename (str): Path to the FASTA file.
    contig_ranges (dict, optional): Maps sequence names to a tuple whose first two
                                    items are the inclusive start and end positions
                                    to keep; any further items (for example a contig
                                    length) are ignored. Sequences without an entry
                                    are not filtered. If None or empty, no filtering
                                    is applied.

    Returns:
    dict: Sequence names mapped to the sorted list of positions that are part of a
          run of two or more identical bases.

    Raises:
    ValueError: If a sequence identifier occurs twice, or if sequence data appears
                before the first identifier line.
    """
    from itertools import groupby

    # Collect the lines of each record and join once, instead of growing a
    # string with += for every line.
    sequence_chunks = {}

    with open(fasta_filename, 'r') as file:
        current_sequence = None
        for line in file:
            line = line.strip()
            if not line:
                # Blank lines carry no sequence data and are not a format error.
                continue
            if line.startswith('>'):  # Identifier line
                current_sequence = line[1:]
                if current_sequence in sequence_chunks:
                    raise ValueError(f"Duplicate sequence identifier found: {current_sequence}")
                sequence_chunks[current_sequence] = []
            elif current_sequence is not None:  # Sequence line
                sequence_chunks[current_sequence].append(line.upper())
            else:
                raise ValueError("File format error: Data encountered before sequence identifier.")

    consecutive_positions_dict = {}

    for sequence_name, chunks in sequence_chunks.items():
        sequence = ''.join(chunks)
        sequence_positions = []

        # groupby yields one group per run of identical adjacent bases.
        run_start = 1
        for _, run in groupby(sequence):
            run_length = sum(1 for _ in run)
            if run_length > 1:
                sequence_positions.extend(range(run_start, run_start + run_length))
            run_start += run_length

        # Optional range filter. Only the first two items are used so that both
        # (start, end) and (start, end, length) tuples are accepted.
        if contig_ranges and sequence_name in contig_ranges:
            range_start, range_end = contig_ranges[sequence_name][:2]
            sequence_positions = [pos for pos in sequence_positions if range_start <= pos <= range_end]

        consecutive_positions_dict[sequence_name] = sequence_positions

    return consecutive_positions_dict

def clean_mutations(contig, bases, positions, contig_ranges, homopolymer_positions, desired_contig):
    """
    Remove deletions that fall on homopolymer positions from a read's mutations.

    Only entries whose base is exactly 'Deletion' and whose position is listed
    for the contig in homopolymer_positions are dropped. Every other mutation
    (including substitutions at those positions) is kept, and the read itself
    is never discarded because of such a deletion.

    Parameters:
    contig (str): Contig the read is assigned to.
    bases (list): Mutation labels, paired element-wise with positions.
    positions (list): Mutation positions, paired element-wise with bases.
    contig_ranges (dict): Accepted for interface compatibility; not used here.
    homopolymer_positions (dict): Maps contig names to a container of
                                  homopolymer positions.
    desired_contig (str): Contig that reads must belong to.

    Returns:
    tuple or None: (contig, comma-joined bases, comma-joined positions) for the
                   retained mutations, or None if contig != desired_contig.
    """
    if contig != desired_contig:
        return None  # Skip this read entirely

    # Look the contig up once instead of once per mutation.
    homopolymer_sites = homopolymer_positions.get(contig, ())

    valid_bases = []
    valid_positions = []
    for base, pos in zip(bases, positions):
        if base == 'Deletion' and pos in homopolymer_sites:
            # Drop this mutation, but keep the rest of the read.
            continue
        valid_bases.append(base)
        valid_positions.append(str(pos))

    return (contig, ','.join(valid_bases), ','.join(valid_positions))




def convert_txt_to_dataframe(text_file_path, desired_contig, contig_ranges, homopolymer_positions, max_lines=None, discard_deletions=False):
    """
    Convert the tab-separated text file of sequencing reads to a pandas DataFrame, applying filters.

    A line is skipped when it is the header (first line containing 'Positions'),
    when it contains the "Read without 'cs' tag found in No_Contig" notice, when
    it has fewer than 9 tab-separated fields, when its positions field cannot be
    parsed as integers, or when its contig differs from desired_contig. Reading
    stops early after more than 5 unparsable position fields.

    Parameters:
    text_file_path (str): Path to the input text file.
    desired_contig (str): The contig value that records must match to be retained.
    contig_ranges (dict): Dictionary specifying valid ranges for contigs; passed
                          through to clean_mutations.
    homopolymer_positions (dict): Dictionary specifying positions of homopolymers;
                                  passed through to clean_mutations.
    max_lines (int, optional): Maximum number of lines (header included) to read
                               from the text file. If None, all lines are processed.
    discard_deletions (bool): When truthy, reads whose mutation field contains
                              'Deletion', 'Insertion' or a lowercase 'n' are
                              dropped. The value is only tested for truthiness, so
                              a non-empty string counts as True.

    Returns:
    pd.DataFrame: DataFrame containing filtered sequencing reads. An empty
                  DataFrame with the expected columns is returned if the
                  DataFrame cannot be built from the collected rows.
    """
    columns = ['read_name', 'contig', 'bases', 'positions', 'read_len', 'query_sequence', 'tabulated', 'read_start', 'read_end']

    data_list = []
    error_count = 0

    with open(text_file_path, 'r', encoding='utf-8') as text_file:
        for line_count, line in enumerate(text_file, start=1):
            if max_lines is not None and line_count > max_lines:
                break
            if line_count == 1 and 'Positions' in line:
                continue
            if "Read without 'cs' tag found in No_Contig" in line:
                continue
            parts = line.strip().split('\t')

            if len(parts) < 9:
                continue
            if discard_deletions and any(x in parts[2] for x in ['Deletion', 'Insertion', 'n']):
                continue

            mutations = parts[2].split(',') if parts[2] else []
            position_values = parts[3].split(',') if parts[3] else []

            try:
                positions = [int(val) for val in position_values]
            except ValueError:
                error_count += 1
                if error_count > 5:
                    print("Too many errors encountered. Exiting loop for manual inspection.")
                    break
                continue

            if parts[1] != desired_contig:
                continue

            cleaned_data = clean_mutations(parts[1], mutations, positions, contig_ranges, homopolymer_positions, desired_contig)
            if cleaned_data:
                # Replace the contig/bases/positions fields with their cleaned
                # versions. Only fields 5-9 are appended so that a line with
                # extra trailing fields cannot produce an over-long row, which
                # would make the DataFrame construction below fail for the
                # whole file.
                data_list.append([parts[0], *cleaned_data, *parts[4:9]])

    try:
        df = pd.DataFrame(data_list, columns=columns)
    except ValueError as e:
        print("Error creating DataFrame:", str(e))
        print("Sample records:", data_list[:10])
        # Return an empty DataFrame with the correct columns in case of an error
        df = pd.DataFrame(columns=columns)

    return df



def main(results_file_path, sample_size):
    """
    Run the Monte Carlo replicates in worker processes and save the aggregated statistics.

    Num_Reps replicates are scheduled in total, with at most
    max_concurrent_tasks worker processes alive at any time. For each replicate
    a sample is drawn with sampler_function from the module-level Full_DF and
    handed to worker_function in a new process, together with the module-level
    contig_ranges and consecutive_positions (all three are set by setup()) and
    a shared manager dict that collects the workers' results. Once all workers
    are done, the collected results are normalised by normalize_data_by_wt,
    summarised by aggregate_statistics and written with save_results_to_file.

    Errors are printed and also sent to the logging system; they are not
    re-raised.

    Parameters:
    results_file_path (str): Path of the JSON file that receives the statistics.
    sample_size: Sample size forwarded to sampler_function for every replicate.
    """
    log = logging.getLogger(__name__)

    total_tasks_to_run = Num_Reps
    tasks_run = 0
    processes = []

    # The context manager shuts the manager's server process down when done;
    # results are copied out of the shared dict before that happens.
    with multiprocessing.Manager() as manager:
        return_dict = manager.dict()  # Collects the results from each process.

        try:
            while tasks_run < total_tasks_to_run or processes:
                # Drop processes that have completed from the list
                processes = [proc for proc in processes if proc.is_alive()]

                # Start another replicate if any remain and a worker slot is free
                if tasks_run < total_tasks_to_run and len(processes) < max_concurrent_tasks:
                    sampled_data = sampler_function(Full_DF, sample_size)

                    new_process = multiprocessing.Process(target=worker_function, args=(tasks_run, contig_ranges, consecutive_positions, sampled_data, return_dict))
                    new_process.start()
                    # Track the process only once it has started, so the
                    # cleanup below never joins an unstarted process.
                    processes.append(new_process)

                    tasks_run += 1

                    # Release the parent's copy of the sample before drawing the next one
                    del sampled_data
                    gc.collect()

                    # Short pause (0.1 s) so that processes are not all spawned at once
                    time.sleep(0.1)

                # Sleep briefly to prevent this polling loop from consuming too much CPU
                time.sleep(0.1)

        except Exception as e:
            print(f"An error occurred: {e}")
            log.exception("Error while scheduling Monte Carlo workers")
        finally:
            # Wait for every worker that is still running, also after an
            # error, so results are only collected once the workers are done.
            for proc in processes:
                proc.join()

        # Copy the results out of the manager dict while the manager is alive
        results = list(return_dict.values())

    if results:
        try:
            extracted_data = [result.get_data() for result in results]
            normalized_results = normalize_data_by_wt(extracted_data, desired_contig)
            final_statistics = aggregate_statistics(normalized_results)
            save_results_to_file(final_statistics, results_file_path)
        except Exception as e:
            print(f"An error occurred during data processing: {e}")
            log.exception("Error while processing Monte Carlo results")
    else:
        print("No successful results to process.")


def normalize_data_by_wt(extracted_data, target_contig):
    """
    Express deletion and substitution counts relative to the WT count.

    Each replicate in extracted_data is a dict keyed by contig. Replicates that
    lack target_contig are reported with a warning and left out. For the
    others, every count under 'deletions' and 'substitutions' is divided by
    the contig's 'WT' count; when 'WT' is absent or not positive, the
    normalised value is 0.

    Returns a list with one dict per retained replicate, holding the keys
    'normalized_deletions' and 'normalized_substitutions'.
    """

    def scale(counts, wt_total):
        # Guard against division by zero: a non-positive WT count yields 0.
        return {
            site: (count / wt_total if wt_total > 0 else 0)
            for site, count in counts.items()
        }

    per_replicate = []

    for replicate in extracted_data:
        if target_contig in replicate:
            record = replicate[target_contig]
            wt_total = record.get('WT', 0)  # a missing 'WT' is treated as zero
            per_replicate.append({
                'normalized_deletions': scale(record['deletions'], wt_total),
                'normalized_substitutions': scale(record['substitutions'], wt_total),
            })
        else:
            print(f"Warning: {target_contig} not found in replicate. Skipping.")

    return per_replicate




def aggregate_statistics(normalized_data):
    """
    Summarise the normalised values of every position across replicates.

    For each position found under 'normalized_deletions' and
    'normalized_substitutions', the values of all replicates are pooled and
    reduced to their mean and their standard deviation (computed with the
    number of values as divisor).

    Returns a dict with the keys 'deletions' and 'substitutions', each mapping
    positions to {'mean': ..., 'std_dev': ...}.
    """
    sources = (
        ('deletions', 'normalized_deletions'),
        ('substitutions', 'normalized_substitutions'),
    )

    # Pool the per-replicate values by mutation type and position.
    pooled = {'deletions': {}, 'substitutions': {}}
    for replicate in normalized_data:
        for kind, source_key in sources:
            bucket = pooled[kind]
            for site, value in replicate[source_key].items():
                bucket.setdefault(site, []).append(value)

    def summarise(samples):
        count = len(samples)
        mean_value = sum(samples) / count
        spread = (sum([(mean_value - x) ** 2 for x in samples]) / count) ** 0.5
        return {
            'mean': mean_value,
            'std_dev': spread
        }

    return {
        kind: {site: summarise(samples) for site, samples in pooled[kind].items()}
        for kind, _ in sources
    }


def save_results_to_file(data, filename):
    """
    Write data to filename as JSON, indented by four spaces.

    Args:
    data (dict): JSON-serialisable data to store.
    filename (str): Path of the output file; an existing file is overwritten.
    """
    with open(filename, mode='w') as handle:
        json.dump(data, handle, indent=4)



def setup(contig_file, fasta_filename, path_to_text_file, desired_contig):
    """
    Prepare the module-level inputs that main() works with.

    Populates three globals:
    contig_ranges: contig name -> (start, end, length), read from the
                   whitespace-separated contig_file (blank lines are ignored).
    consecutive_positions: result of analyze_fasta for fasta_filename.
    Full_DF: filtered reads from path_to_text_file for desired_contig, which
             are also written tab-separated to MonteCarloDataFrame.
    """
    global contig_ranges, consecutive_positions, Full_DF

    # Contig ranges: one "<name> <start> <end> <length>" record per line.
    contig_ranges = {}
    with open(contig_file, 'r') as handle:
        for raw_line in handle:
            fields = raw_line.split()
            if not fields:
                continue
            contig_ranges[fields[0]] = (int(fields[1]), int(fields[2]), int(fields[3]))

    # Homopolymer positions of every sequence in the FASTA file.
    consecutive_positions = analyze_fasta(fasta_filename)

    # Build the full, filtered read table that replicates are sampled from.
    Full_DF = convert_txt_to_dataframe(
        path_to_text_file,
        desired_contig,
        contig_ranges,
        consecutive_positions,
        None,
        discard_deletions,
    )
    Full_DF.to_csv(MonteCarloDataFrame, sep='\t', index=False)

    print("Setup completed. Monte Carlo Simulation beginning")


########### Execute Code ###########

#Execute main function
if __name__ == "__main__":
    # Build the module-level inputs (contig_ranges, consecutive_positions,
    # Full_DF) that main() reads.
    setup(contig_file, fasta_filename, path_to_text_file, desired_contig)

    # Set up logging: INFO and above is written to the sample's
    # MonteCarloError.log file, with a timestamp and level on each record.
    logging.basicConfig(filename=MonteCarloErrorLog, level=logging.INFO, format='%(asctime)s:%(levelname)s:%(message)s')
    logger = logging.getLogger(__name__)

    # Run the simulation with the output path and sample size configured at
    # the top of the file.
    path_to_save_results = aggregated_results
    sample_size = Sample_Size
    main(path_to_save_results, sample_size)




